#!/bin/bash

# --- Job configuration --------------------------------------------------------
# Abort on command failures, references to unset variables, and failures in any
# pipeline stage, instead of launching training with a half-broken setup.
set -euo pipefail

# Parameters supplied through the environment (e.g. `sbatch --export=...`):
#   NUM_GPU      - worker processes to start per node (torchrun --nproc_per_node)
#   EXP_NAME     - experiment name; used for results/<EXP_NAME> and --rdzv_id
#   CLEARML_NAME - optional; passed to train.py as --clearml_name (may be empty)
: "${NUM_GPU:?NUM_GPU must be set to the number of GPUs per node}"
: "${EXP_NAME:?EXP_NAME must be set to a non-empty experiment name}"
CLEARML_NAME=${CLEARML_NAME-}

echo "Python version:"
# `command -v` is the portable replacement for `which`. Also print the actual
# version, which the header promises but a bare path does not show.
if python_bin=$(command -v python); then
    echo "$python_bin"
    "$python_bin" --version
else
    echo "warning: python not found on PATH" >&2
fi

export PYTHONFAULTHANDLER=1
export OMP_NUM_THREADS=16

# --- Distributed setup (SLURM) ------------------------------------------------
: "${SLURM_JOB_NODELIST:?SLURM_JOB_NODELIST is unset; run this inside a SLURM allocation}"
# Expand the compressed node list once (one hostname per line) and reuse it.
# A failing scontrol aborts here because of `set -e`.
node_list=$(scontrol show hostnames "$SLURM_JOB_NODELIST")
if [[ -z "$node_list" ]]; then
    echo "error: scontrol returned no hosts for '${SLURM_JOB_NODELIST}'" >&2
    exit 1
fi
mapfile -t node_hosts <<< "$node_list"

# NOTE: WORLD_SIZE here is the number of *nodes* (it feeds torchrun --nnodes),
# not the number of training processes.
WORLD_SIZE=${#node_hosts[@]}
# The first node of the allocation hosts the rendezvous endpoint.
MASTER_ADDR=${node_hosts[0]}
MASTER_PORT=51234
export WORLD_SIZE MASTER_ADDR MASTER_PORT

# Output locations and this task's rank (used to name its log file). Outside
# `srun`, SLURM_PROCID may be unset, so fall back to 0 rather than ".log".
export OUTPUT_DIR="results/${EXP_NAME}"
export LOG_DIR="${OUTPUT_DIR}/logs"
export RANK=${SLURM_PROCID:-0}

# Print the resolved configuration to the job's standard output for debugging.
echo "OUTPUT_DIR ${OUTPUT_DIR}"
echo "Distributed training:"
echo "MASTER_ADDR ${MASTER_ADDR}"
echo "MASTER_PORT ${MASTER_PORT}"
echo "RANK ${RANK}"

# `mkdir -p` also creates OUTPUT_DIR as LOG_DIR's parent; both are listed to keep
# the intent explicit. The paths are quoted so that names with spaces stay intact.
if ! mkdir -p -- "$OUTPUT_DIR" "$LOG_DIR"; then
    echo "error: could not create log directory '${LOG_DIR}'" >&2
    exit 1
fi

# Launch NUM_GPU workers on each of the WORLD_SIZE nodes. All nodes meet at the
# c10d rendezvous on MASTER_ADDR, one port above MASTER_PORT. The worker group
# may be restarted up to 3 times (--max_restarts).
# Only stdout is redirected to the per-rank log. stderr, including faulthandler
# tracebacks enabled via PYTHONFAULTHANDLER, still goes to the job's stderr.
torchrun \
    --nproc_per_node="$NUM_GPU" \
    --nnodes="${WORLD_SIZE}:${WORLD_SIZE}" \
    --rdzv_id="$EXP_NAME" \
    --rdzv_backend=c10d \
    --rdzv_endpoint="${MASTER_ADDR}:$((MASTER_PORT + 1))" \
    --max_restarts=3 \
    train.py \
    --outputdir "${OUTPUT_DIR}" \
    --datadir "./data_csn/codesearchnet/data_prepared" \
    --clearml_name "${CLEARML_NAME}" \
    --maxlen 512 \
    --modelname "codet5-large" \
    --tasks codesearchnet_nl2tree codesearchnet_tree2nl codesearchnet_maskednodepred codesearchnet_subtreepred \
    --train_bsz 16 \
    --eval_bsz 16 \
    --grad_acc_steps 2 \
    --nepochs 1 \
    --lr "2e-4" \
    --maskrate 0.15 \
    --subsample_ds 20000 \
    --fp16 \
    --use_zero2 > "${LOG_DIR}/${RANK}.log"